import multiprocessing
import math
import cv2 as cv
import keras.backend as K
import numpy as np
from tensorflow.python.client import device_lib

from config import epsilon, epsilon_sqr
from config import img_cols
from config import img_rows
from config import unknown_code

from tensorflow.python.debug.lib.debug_data import InconvertibleTensorProto
import tensorflow as tf


def L1_loss(y_true, y_pred):
    """Return the mean absolute difference between `y_true` and `y_pred`.

    The difference is computed element-wise and `K.mean` is called without an
    axis argument.
    """
    abs_residual = K.abs(y_true - y_pred)
    return K.mean(abs_residual)


def L2_loss(y_true, y_pred):
    """Return the mean squared difference between `y_true` and `y_pred`.

    The difference is computed element-wise and `K.mean` is called without an
    axis argument.
    """
    squared_residual = K.square(y_true - y_pred)
    return K.mean(squared_residual)


def edge_loss(y_true, y_pred):
    """Placeholder for the edge loss; it is not implemented.

    Raises:
        NotImplementedError: always. An explicit raise is used instead of
            `assert False` because assert statements are removed when Python
            runs with `-O`, which would make this stub silently return None.
    """
    raise NotImplementedError("edge_loss is not implemented")

# getting the device names of the available GPUs
def get_available_gpus():
    """Return the names of all local devices whose `device_type` is 'GPU'.

    The result is a list of device names (not a count); it is empty when
    `device_lib.list_local_devices()` reports no GPU device.
    """
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names


# getting the number of CPUs
def get_available_cpus():
    """Return the CPU count reported by `multiprocessing.cpu_count()`."""
    cpu_total = multiprocessing.cpu_count()
    return cpu_total


def get_final_output(out, trimap):
    """Blend `out` into `trimap` wherever the trimap equals `unknown_code`.

    Positions where `trimap == unknown_code` take their value from `out`;
    every other position keeps its `trimap` value. The selection is done
    arithmetically with a float32 0/1 mask.
    """
    unknown = np.equal(trimap, unknown_code).astype(np.float32)
    known = 1 - unknown
    return known * trimap + unknown * out

def patch_dims(mat_size, patch_size):
    """Return how many patches of `patch_size` are needed along each axis.

    Each entry of `mat_size` is divided by `patch_size` and rounded up, so a
    partially filled trailing patch counts as a whole one. The result is an
    integer NumPy array with one entry per entry of `mat_size`.
    """
    extent = np.array(mat_size)
    patch_counts = np.ceil(extent / patch_size)
    return patch_counts.astype(int)

def create_patches(mat, patch_size):
    """Split a (height, width, 4) array into square, zero-padded patches.

    Args:
        mat: 3-D array whose last axis has 4 channels (R, G, B, trimap).
        patch_size: side length of each square patch.

    Returns:
        A float32 array of shape (patch_count, patch_size, patch_size, 4).
        Patches that reach past the bottom or right edge of `mat` are
        zero-padded. Patch (row y, column x) is stored at index
        `y + x * rows`, i.e. patches are ordered column by column; this is the
        layout `assemble_patches` decodes.

    Raises:
        AssertionError: if `mat` is not 3-D or does not have 4 channels.
    """
    mat_size = mat.shape
    assert len(mat_size) == 3, "Input mat need to have 3 dimensions (height, width, channels)"
    assert mat_size[-1] == 4, "Input mat need to have 4 channels (R, G, B, trimap)"

    patches_dim = patch_dims(mat_size=mat_size[:2], patch_size=patch_size)
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    patches_count = int(np.prod(patches_dim))

    patches = np.zeros(shape=(patches_count, patch_size, patch_size, 4), dtype=np.float32)
    for y in range(patches_dim[0]):
        y_start = y * patch_size
        for x in range(patches_dim[1]):
            x_start = x * patch_size

            # Slicing clips at the array border, so edge patches may be
            # smaller than patch_size.
            single_patch = mat[y_start: y_start + patch_size, x_start: x_start + patch_size, :]

            # Copy into the top-left corner; the rest of the pre-zeroed patch
            # stays as bottom/right zero padding.
            real_patch_h, real_patch_w = single_patch.shape[:2]
            patch_id = y + x * patches_dim[0]
            patches[patch_id, :real_patch_h, :real_patch_w, :] = single_patch

    return patches

def assemble_patches(pred_patches, mat_size, patch_size):
    """Tile per-patch predictions back into a single 2-D uint8 array.

    Patch `i` is placed at row `i % rows` and column `i // rows` of the patch
    grid, where `rows` comes from `patch_dims(mat_size, patch_size)`. The
    returned array has the padded size
    (patch_size * rows, patch_size * cols); it is not cropped to `mat_size`.
    Values are stored in a uint8 buffer.
    """
    grid_rows, grid_cols = patch_dims(mat_size=mat_size, patch_size=patch_size)
    canvas = np.zeros(shape=(patch_size * grid_rows, patch_size * grid_cols), dtype=np.uint8)

    for index in range(pred_patches.shape[0]):
        # Column-major layout: the quotient selects the column, the remainder the row.
        grid_col, grid_row = divmod(index, grid_rows)
        top = grid_row * patch_size
        left = grid_col * patch_size
        canvas[top:top + patch_size, left:left + patch_size] = pred_patches[index]

    return canvas

def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
    """Crop a window from `mat`, zero-padding where it runs past the border.

    Args:
        mat: 2-D array, or an array with more dimensions (a 3-channel buffer
            is allocated for any non-2-D input).
        x: left column of the crop window.
        y: top row of the crop window.
        crop_size: (height, width) of the window; defaults to
            (img_rows, img_cols).

    Returns:
        A float32 array. The crop is copied into the top-left corner of a
        zero buffer of `crop_size`; if `crop_size` differs from
        (img_rows, img_cols) the buffer is resized with nearest-neighbour
        interpolation so the result has img_rows rows and img_cols columns.
    """
    crop_height, crop_width = crop_size
    if len(mat.shape) == 2:
        ret = np.zeros((crop_height, crop_width), np.float32)
    else:
        ret = np.zeros((crop_height, crop_width, 3), np.float32)
    # Slicing clips at the array border, so the crop may be smaller than requested.
    crop = mat[y:y + crop_height, x:x + crop_width]
    h, w = crop.shape[:2]
    ret[0:h, 0:w] = crop
    if (crop_height, crop_width) != (img_rows, img_cols):
        # cv.resize expects dsize as (width, height), i.e. (cols, rows).
        ret = cv.resize(ret, dsize=(img_cols, img_rows), interpolation=cv.INTER_NEAREST)
    return ret


def draw_str(dst, target, s):
    """Draw the string `s` onto `dst` at the (x, y) position `target`.

    The text is drawn twice: first in black with thickness 2, offset by one
    pixel in x and y, then in white at the requested position, which gives
    the white text a dark outline/shadow. `dst` is modified in place.
    """
    x, y = target
    font = cv.FONT_HERSHEY_PLAIN
    scale = 1.0
    shadow_origin = (x + 1, y + 1)
    text_origin = (x, y)
    cv.putText(dst, s, shadow_origin, font, scale, (0, 0, 0), thickness=2, lineType=cv.LINE_AA)
    cv.putText(dst, s, text_origin, font, scale, (255, 255, 255), lineType=cv.LINE_AA)


# for debug use
def callable_check_loss(datum, tensor):
    """Report whether a specific loss tensor contains NaN or Inf values.

    Args:
        datum: debug datum whose `_node_name` identifies the node the tensor
            came from.
        tensor: the tensor value, or an `InconvertibleTensorProto`.

    Returns:
        A truthy value only when `tensor` is a floating, complex or integer
        array, the node name starts with
        'loss/refinement_pred_loss/Mean_3', and the tensor contains at least
        one NaN or Inf. False in every other case.
    """
    if isinstance(tensor, InconvertibleTensorProto):
        # Uninitialized tensor doesn't have bad numerical values.
        # Also return False for data types that cannot be represented as numpy
        # arrays.
        return False

    # np.complexfloating is used because the np.complex alias was removed in
    # NumPy 1.24 and raised AttributeError for integer tensors.
    is_numeric = (np.issubdtype(tensor.dtype, np.floating) or
                  np.issubdtype(tensor.dtype, np.complexfloating) or
                  np.issubdtype(tensor.dtype, np.integer))
    if not is_numeric:
        return False

    # Only the watched loss node is checked; all other nodes are ignored.
    if not datum._node_name.startswith('loss/refinement_pred_loss/Mean_3'):
        return False

    return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))